{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对应 `tf.keras` 的01~02章节"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:26.504604Z",
     "start_time": "2025-02-21T13:46:26.452859Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Report the interpreter and library versions so the run can be reproduced.\n",
    "print(sys.version_info)\n",
    "for module in (mpl, np, pd, sklearn, torch):\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "# Use the first CUDA device when torch reports one as available, otherwise the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=10, micro=7, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.0.3\n",
      "sklearn 1.6.1\n",
      "torch 2.6.0+cpu\n",
      "cpu\n"
     ]
    }
   ],
   "execution_count": 50
  },
  {
   "cell_type": "code",
   "source": [
    "28*28"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:26.787627Z",
     "start_time": "2025-02-21T13:46:26.773299Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "784"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 51
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:26.953974Z",
     "start_time": "2025-02-21T13:46:26.800629Z"
    }
   },
   "source": [
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "# Preprocessing applied to every sample as it is read from the dataset.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # convert to a tensor with pixel values scaled to [0, 1]\n",
    "    # transforms.Normalize(mean, std)  # standardisation; mean/std are the dataset mean and standard deviation\n",
    "])\n",
    "\n",
    "# Fashion-MNIST: clothing-item image classification with 60000 training images\n",
    "# and 10000 test images. Both splits use the same root directory, download flag\n",
    "# and transform; only the `train` flag differs.\n",
    "train_ds, test_ds = (\n",
    "    datasets.FashionMNIST(root=\"data\", train=is_train, download=True, transform=transform)\n",
    "    for is_train in (True, False)\n",
    ")\n",
    "\n",
    "# torchvision does not provide a train/validation split for this dataset;\n",
    "# one can be built manually with torch.utils.data.Dataset."
   ],
   "outputs": [],
   "execution_count": 52
  },
  {
   "cell_type": "code",
   "source": [
    "type(train_ds)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:26.972088Z",
     "start_time": "2025-02-21T13:46:26.959030Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torchvision.datasets.mnist.FashionMNIST"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 53
  },
  {
   "cell_type": "code",
   "source": "len(train_ds)",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:26.986976Z",
     "start_time": "2025-02-21T13:46:26.976205Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60000"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 54
  },
  {
   "cell_type": "code",
   "source": [
    "type(train_ds[0])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:27.020382Z",
     "start_time": "2025-02-21T13:46:26.991973Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tuple"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 55
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:27.037264Z",
     "start_time": "2025-02-21T13:46:27.025172Z"
    }
   },
   "source": [
    "# Index the dataset by position: each item is an (image, label) tuple, so the\n",
    "# first sample is unpacked into its feature tensor and its class label.\n",
    "img, label = train_ds[0]\n",
    "# The shape is (1, 28, 28) because the channel dimension comes first (C, H, W).\n",
    "img.shape"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 28, 28])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 56
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-21T13:49:40.690423Z",
     "start_time": "2025-02-21T13:49:40.662743Z"
    }
   },
   "cell_type": "code",
   "source": "type(img) #tensor中文是 张量,和numpy的ndarray类似",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 67
  },
  {
   "cell_type": "code",
   "source": "img[0]",
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:49:42.605699Z",
     "start_time": "2025-02-21T13:49:42.579746Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510, 0.2863, 0.0000,\n",
       "         0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0039,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333, 0.4980, 0.2431,\n",
       "         0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118, 0.0157, 0.0000, 0.0000,\n",
       "         0.0118],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000, 0.6902, 0.5255,\n",
       "         0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.0392,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255, 0.8118, 0.6980,\n",
       "         0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902, 0.3020, 0.5098, 0.2824,\n",
       "         0.0588],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745, 0.8549, 0.8471,\n",
       "         0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725, 0.5529, 0.3451, 0.6745,\n",
       "         0.2588],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098, 0.9137, 0.8980,\n",
       "         0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980, 0.4824, 0.7686, 0.8980,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471, 0.8745, 0.8941,\n",
       "         0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667, 0.8745, 0.9608, 0.6784,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549, 0.8353, 0.7765,\n",
       "         0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745, 0.8627, 0.9529, 0.7922,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314, 0.8549, 0.7529,\n",
       "         0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314, 0.8863, 0.7725, 0.8196,\n",
       "         0.2039],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627, 0.8549, 0.7961,\n",
       "         0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627, 0.9608, 0.4667, 0.6549,\n",
       "         0.2196],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020, 0.8941, 0.9412,\n",
       "         0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510, 0.8510, 0.8196, 0.3608,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745, 0.8706, 0.8588,\n",
       "         0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431, 0.8549, 1.0000, 0.3020,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667, 0.8549, 0.8157,\n",
       "         0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431, 0.8784, 0.9569, 0.6235,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196, 0.7412,\n",
       "         0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039, 0.8275, 0.9020,\n",
       "         0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725, 0.9137, 0.9333, 0.8431,\n",
       "         0.0000],\n",
       "        [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157, 0.8000,\n",
       "         0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569, 0.8078, 0.8745,\n",
       "         1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275, 0.8627, 0.9098, 0.9647,\n",
       "         0.0000],\n",
       "        [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392, 0.8039,\n",
       "         0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000, 0.8980, 0.8667,\n",
       "         0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196, 0.8706, 0.8941, 0.8824,\n",
       "         0.0000],\n",
       "        [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176, 0.9765,\n",
       "         0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863, 0.4157, 0.4588,\n",
       "         0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745, 0.8745, 0.8784, 0.8980,\n",
       "         0.1137],\n",
       "        [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824, 0.8471,\n",
       "         0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647, 0.8902, 0.9608,\n",
       "         0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706, 0.8627, 0.8667, 0.9020,\n",
       "         0.2627],\n",
       "        [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451, 0.7608,\n",
       "         0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255, 0.8824, 0.8471,\n",
       "         0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745, 0.7098, 0.8039, 0.8078,\n",
       "         0.4510],\n",
       "        [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686, 0.8000,\n",
       "         0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686, 0.7608, 0.7490,\n",
       "         0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118, 0.6549, 0.6941, 0.8235,\n",
       "         0.3608],\n",
       "        [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745, 0.6863,\n",
       "         0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765, 0.8000, 0.8196,\n",
       "         0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608, 0.7529, 0.8471, 0.6667,\n",
       "         0.0000],\n",
       "        [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294, 0.9373,\n",
       "         0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569, 0.7490, 0.7020,\n",
       "         0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588, 0.3882, 0.2275, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569, 0.2392,\n",
       "         0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000]])"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 68
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-21T13:49:45.148979Z",
     "start_time": "2025-02-21T13:49:45.123269Z"
    }
   },
   "cell_type": "code",
   "source": "img",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,\n",
       "          0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0039, 0.0039, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,\n",
       "          0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,\n",
       "          0.0157, 0.0000, 0.0000, 0.0118],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,\n",
       "          0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0471, 0.0392, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,\n",
       "          0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,\n",
       "          0.3020, 0.5098, 0.2824, 0.0588],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,\n",
       "          0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,\n",
       "          0.5529, 0.3451, 0.6745, 0.2588],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,\n",
       "          0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,\n",
       "          0.4824, 0.7686, 0.8980, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,\n",
       "          0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,\n",
       "          0.8745, 0.9608, 0.6784, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,\n",
       "          0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,\n",
       "          0.8627, 0.9529, 0.7922, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,\n",
       "          0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,\n",
       "          0.8863, 0.7725, 0.8196, 0.2039],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,\n",
       "          0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,\n",
       "          0.9608, 0.4667, 0.6549, 0.2196],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,\n",
       "          0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,\n",
       "          0.8510, 0.8196, 0.3608, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,\n",
       "          0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,\n",
       "          0.8549, 1.0000, 0.3020, 0.0000],\n",
       "         [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,\n",
       "          0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,\n",
       "          0.8784, 0.9569, 0.6235, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,\n",
       "          0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,\n",
       "          0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,\n",
       "          0.9137, 0.9333, 0.8431, 0.0000],\n",
       "         [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,\n",
       "          0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,\n",
       "          0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,\n",
       "          0.8627, 0.9098, 0.9647, 0.0000],\n",
       "         [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,\n",
       "          0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,\n",
       "          0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,\n",
       "          0.8706, 0.8941, 0.8824, 0.0000],\n",
       "         [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,\n",
       "          0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,\n",
       "          0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,\n",
       "          0.8745, 0.8784, 0.8980, 0.1137],\n",
       "         [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,\n",
       "          0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,\n",
       "          0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,\n",
       "          0.8627, 0.8667, 0.9020, 0.2627],\n",
       "         [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,\n",
       "          0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,\n",
       "          0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,\n",
       "          0.7098, 0.8039, 0.8078, 0.4510],\n",
       "         [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,\n",
       "          0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,\n",
       "          0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,\n",
       "          0.6549, 0.6941, 0.8235, 0.3608],\n",
       "         [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,\n",
       "          0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,\n",
       "          0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,\n",
       "          0.7529, 0.8471, 0.6667, 0.0000],\n",
       "         [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,\n",
       "          0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,\n",
       "          0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,\n",
       "          0.3882, 0.2275, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,\n",
       "          0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000]]])"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 69
  },
  {
   "cell_type": "code",
   "source": [
    "#计算均值和方差\n",
    "def cal_mean_std(ds):\n",
    "    mean = 0.\n",
    "    std = 0.\n",
    "    for img, _ in ds: # 遍历每张图片,img.shape=[1,28,28]\n",
    "        mean += img.mean(dim=(1, 2)) # 计算每张图片功能，dim=(1, 2)表示在1和2维度上求均值\n",
    "        std += img.std(dim=(1, 2))\n",
    "    mean /= len(ds)\n",
    "    std /= len(ds)\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "print(cal_mean_std(train_ds))\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:39.420014Z",
     "start_time": "2025-02-21T13:46:27.361065Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([0.2860]), tensor([0.3205]))\n"
     ]
    }
   ],
   "execution_count": 60
  },
  {
   "cell_type": "code",
   "source": [
    "type(img)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:39.437720Z",
     "start_time": "2025-02-21T13:46:39.421794Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 61
  },
  {
   "cell_type": "code",
   "source": [
    "label"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:39.453305Z",
     "start_time": "2025-02-21T13:46:39.440565Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 62
  },
  {
   "cell_type": "code",
   "source": [
    "type(img) #tensor中文是 张量,和numpy的ndarray类似"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:39.470845Z",
     "start_time": "2025-02-21T13:46:39.457427Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 63
  },
  {
   "cell_type": "code",
   "source": [
    "label"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:39.486441Z",
     "start_time": "2025-02-21T13:46:39.474188Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 64
  },
  {
   "cell_type": "code",
   "source": [
    "# Inspect an image's size, mode and (for grayscale images) its raw pixel values.\n",
    "def show_img_content(img):\n",
    "    \"\"\"Print the size and mode of an image and, for mode 'L', its pixel values.\n",
    "\n",
    "    Args:\n",
    "        img: a PIL image, or any input `to_pil_image` accepts (for example the\n",
    "            tensor produced by `transforms.ToTensor()`); non-PIL input is\n",
    "            converted to a PIL image before it is inspected.\n",
    "    \"\"\"\n",
    "    from PIL import Image\n",
    "    from torchvision.transforms.functional import to_pil_image\n",
    "\n",
    "    # On a tensor, `size` and `mode` are methods, so printing them directly shows\n",
    "    # bound-method reprs instead of image metadata; convert to PIL up front.\n",
    "    if not isinstance(img, Image.Image):\n",
    "        img = to_pil_image(img)\n",
    "\n",
    "    print(\"图像大小:\", img.size)\n",
    "    print(\"图像模式:\", img.mode)\n",
    "\n",
    "    # Mode 'L' is a single-channel (grayscale) image: print its pixels as a flat list.\n",
    "    if img.mode == 'L':\n",
    "        pixel_values = list(img.getdata())\n",
    "        print(pixel_values)\n",
    "\n",
    "\n",
    "show_img_content(img)  # works whether or not transforms.ToTensor() is applied to the dataset"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:46:39.504361Z",
     "start_time": "2025-02-21T13:46:39.490751Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图像大小: <built-in method size of Tensor object at 0x000001B94489D440>\n",
      "图像模式: <built-in method mode of Tensor object at 0x000001B94489D440>\n"
     ]
    }
   ],
   "execution_count": 65
  },
  {
   "cell_type": "code",
   "source": [
    "# Display one image; channel-first single-channel tensors are supported as well.\n",
    "def show_single_image(img_arr):\n",
    "    \"\"\"Show an image with the `binary` colormap and a colour bar.\n",
    "\n",
    "    Args:\n",
    "        img_arr: image data that `np.asarray` can convert (array, PIL image or\n",
    "            CPU tensor). A (1, H, W) input has its leading channel axis removed\n",
    "            so that it can be drawn as an (H, W) image.\n",
    "    \"\"\"\n",
    "    pixels = np.asarray(img_arr)\n",
    "    # imshow accepts (H, W), (H, W, 3) or (H, W, 4) only; the (1, H, W) tensor from\n",
    "    # ToTensor() raises \"Invalid shape\", so drop its singleton channel axis. Inputs\n",
    "    # whose last axis is 3 or 4 are left alone because imshow already accepts them.\n",
    "    if pixels.ndim == 3 and pixels.shape[0] == 1 and pixels.shape[-1] not in (3, 4):\n",
    "        pixels = pixels[0]\n",
    "    plt.imshow(pixels, cmap=\"binary\")  # draw the image\n",
    "    plt.colorbar()  # add the colour scale\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show_single_image(img)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2025-02-21T13:49:53.598046Z",
     "start_time": "2025-02-21T13:49:53.314770Z"
    }
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Invalid shape (1, 28, 28) for image data",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[70], line 8\u001B[0m\n\u001B[0;32m      4\u001B[0m     plt\u001B[38;5;241m.\u001B[39mcolorbar() \u001B[38;5;66;03m# 显示颜色条\u001B[39;00m\n\u001B[0;32m      5\u001B[0m     plt\u001B[38;5;241m.\u001B[39mshow()\n\u001B[1;32m----> 8\u001B[0m \u001B[43mshow_single_image\u001B[49m\u001B[43m(\u001B[49m\u001B[43mimg\u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[1;32mIn[70], line 3\u001B[0m, in \u001B[0;36mshow_single_image\u001B[1;34m(img_arr)\u001B[0m\n\u001B[0;32m      2\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mshow_single_image\u001B[39m(img_arr):\n\u001B[1;32m----> 3\u001B[0m     \u001B[43mplt\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43mimg_arr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcmap\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mbinary\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m \u001B[38;5;66;03m# 显示图片\u001B[39;00m\n\u001B[0;32m      4\u001B[0m     plt\u001B[38;5;241m.\u001B[39mcolorbar() \u001B[38;5;66;03m# 显示颜色条\u001B[39;00m\n\u001B[0;32m      5\u001B[0m     plt\u001B[38;5;241m.\u001B[39mshow()\n",
      "File \u001B[1;32mD:\\python3.10\\lib\\site-packages\\matplotlib\\pyplot.py:3592\u001B[0m, in \u001B[0;36mimshow\u001B[1;34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)\u001B[0m\n\u001B[0;32m   3570\u001B[0m \u001B[38;5;129m@_copy_docstring_and_deprecators\u001B[39m(Axes\u001B[38;5;241m.\u001B[39mimshow)\n\u001B[0;32m   3571\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mimshow\u001B[39m(\n\u001B[0;32m   3572\u001B[0m     X: ArrayLike \u001B[38;5;241m|\u001B[39m PIL\u001B[38;5;241m.\u001B[39mImage\u001B[38;5;241m.\u001B[39mImage,\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m   3590\u001B[0m     \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs,\n\u001B[0;32m   3591\u001B[0m ) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m AxesImage:\n\u001B[1;32m-> 3592\u001B[0m     __ret \u001B[38;5;241m=\u001B[39m gca()\u001B[38;5;241m.\u001B[39mimshow(\n\u001B[0;32m   3593\u001B[0m         X,\n\u001B[0;32m   3594\u001B[0m         cmap\u001B[38;5;241m=\u001B[39mcmap,\n\u001B[0;32m   3595\u001B[0m         norm\u001B[38;5;241m=\u001B[39mnorm,\n\u001B[0;32m   3596\u001B[0m         aspect\u001B[38;5;241m=\u001B[39maspect,\n\u001B[0;32m   3597\u001B[0m         interpolation\u001B[38;5;241m=\u001B[39minterpolation,\n\u001B[0;32m   3598\u001B[0m         alpha\u001B[38;5;241m=\u001B[39malpha,\n\u001B[0;32m   3599\u001B[0m         vmin\u001B[38;5;241m=\u001B[39mvmin,\n\u001B[0;32m   3600\u001B[0m         vmax\u001B[38;5;241m=\u001B[39mvmax,\n\u001B[0;32m   3601\u001B[0m         colorizer\u001B[38;5;241m=\u001B[39mcolorizer,\n\u001B[0;32m   3602\u001B[0m         origin\u001B[38;5;241m=\u001B[39morigin,\n\u001B[0;32m   3603\u001B[0m         extent\u001B[38;5;241m=\u001B[39mextent,\n\u001B[0;32m   3604\u001B[0m         interpolation_stage\u001B[38;5;241m=\u001B[39minterpolation_stage,\n\u001B[0;32m   3605\u001B[0m         
filternorm\u001B[38;5;241m=\u001B[39mfilternorm,\n\u001B[0;32m   3606\u001B[0m         filterrad\u001B[38;5;241m=\u001B[39mfilterrad,\n\u001B[0;32m   3607\u001B[0m         resample\u001B[38;5;241m=\u001B[39mresample,\n\u001B[0;32m   3608\u001B[0m         url\u001B[38;5;241m=\u001B[39murl,\n\u001B[0;32m   3609\u001B[0m         \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m({\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mdata\u001B[39m\u001B[38;5;124m\"\u001B[39m: data} \u001B[38;5;28;01mif\u001B[39;00m data \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;28;01melse\u001B[39;00m {}),\n\u001B[0;32m   3610\u001B[0m         \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs,\n\u001B[0;32m   3611\u001B[0m     )\n\u001B[0;32m   3612\u001B[0m     sci(__ret)\n\u001B[0;32m   3613\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m __ret\n",
      "File \u001B[1;32mD:\\python3.10\\lib\\site-packages\\matplotlib\\__init__.py:1521\u001B[0m, in \u001B[0;36m_preprocess_data.<locals>.inner\u001B[1;34m(ax, data, *args, **kwargs)\u001B[0m\n\u001B[0;32m   1518\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m   1519\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21minner\u001B[39m(ax, \u001B[38;5;241m*\u001B[39margs, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m   1520\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m data \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m-> 1521\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m func(\n\u001B[0;32m   1522\u001B[0m             ax,\n\u001B[0;32m   1523\u001B[0m             \u001B[38;5;241m*\u001B[39m\u001B[38;5;28mmap\u001B[39m(cbook\u001B[38;5;241m.\u001B[39msanitize_sequence, args),\n\u001B[0;32m   1524\u001B[0m             \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m{k: cbook\u001B[38;5;241m.\u001B[39msanitize_sequence(v) \u001B[38;5;28;01mfor\u001B[39;00m k, v \u001B[38;5;129;01min\u001B[39;00m kwargs\u001B[38;5;241m.\u001B[39mitems()})\n\u001B[0;32m   1526\u001B[0m     bound \u001B[38;5;241m=\u001B[39m new_sig\u001B[38;5;241m.\u001B[39mbind(ax, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m   1527\u001B[0m     auto_label \u001B[38;5;241m=\u001B[39m (bound\u001B[38;5;241m.\u001B[39marguments\u001B[38;5;241m.\u001B[39mget(label_namer)\n\u001B[0;32m   1528\u001B[0m                   \u001B[38;5;129;01mor\u001B[39;00m bound\u001B[38;5;241m.\u001B[39mkwargs\u001B[38;5;241m.\u001B[39mget(label_namer))\n",
      "File \u001B[1;32mD:\\python3.10\\lib\\site-packages\\matplotlib\\axes\\_axes.py:5945\u001B[0m, in \u001B[0;36mAxes.imshow\u001B[1;34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)\u001B[0m\n\u001B[0;32m   5942\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m aspect \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m   5943\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mset_aspect(aspect)\n\u001B[1;32m-> 5945\u001B[0m \u001B[43mim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mset_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m   5946\u001B[0m im\u001B[38;5;241m.\u001B[39mset_alpha(alpha)\n\u001B[0;32m   5947\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m im\u001B[38;5;241m.\u001B[39mget_clip_path() \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m   5948\u001B[0m     \u001B[38;5;66;03m# image does not already have clipping set, clip to Axes patch\u001B[39;00m\n",
      "File \u001B[1;32mD:\\python3.10\\lib\\site-packages\\matplotlib\\image.py:675\u001B[0m, in \u001B[0;36m_ImageBase.set_data\u001B[1;34m(self, A)\u001B[0m\n\u001B[0;32m    673\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(A, PIL\u001B[38;5;241m.\u001B[39mImage\u001B[38;5;241m.\u001B[39mImage):\n\u001B[0;32m    674\u001B[0m     A \u001B[38;5;241m=\u001B[39m pil_to_array(A)  \u001B[38;5;66;03m# Needed e.g. to apply png palette.\u001B[39;00m\n\u001B[1;32m--> 675\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_normalize_image_array\u001B[49m\u001B[43m(\u001B[49m\u001B[43mA\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    676\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_imcache \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m    677\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mstale \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n",
      "File \u001B[1;32mD:\\python3.10\\lib\\site-packages\\matplotlib\\image.py:643\u001B[0m, in \u001B[0;36m_ImageBase._normalize_image_array\u001B[1;34m(A)\u001B[0m\n\u001B[0;32m    641\u001B[0m     A \u001B[38;5;241m=\u001B[39m A\u001B[38;5;241m.\u001B[39msqueeze(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)  \u001B[38;5;66;03m# If just (M, N, 1), assume scalar and apply colormap.\u001B[39;00m\n\u001B[0;32m    642\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m2\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m \u001B[38;5;129;01mand\u001B[39;00m A\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m] \u001B[38;5;129;01min\u001B[39;00m [\u001B[38;5;241m3\u001B[39m, \u001B[38;5;241m4\u001B[39m]):\n\u001B[1;32m--> 643\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mInvalid shape \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mA\u001B[38;5;241m.\u001B[39mshape\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m for image data\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m    644\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m:\n\u001B[0;32m    645\u001B[0m     \u001B[38;5;66;03m# If the input data has values outside the valid range (after\u001B[39;00m\n\u001B[0;32m    646\u001B[0m     \u001B[38;5;66;03m# normalisation), we issue a warning and then clip X to the bounds\u001B[39;00m\n\u001B[0;32m    647\u001B[0m     \u001B[38;5;66;03m# - otherwise casting wraps extreme values, hiding outliers and\u001B[39;00m\n\u001B[0;32m    648\u001B[0m     \u001B[38;5;66;03m# making reliable interpretation impossible.\u001B[39;00m\n\u001B[0;32m    
649\u001B[0m     high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m255\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m np\u001B[38;5;241m.\u001B[39missubdtype(A\u001B[38;5;241m.\u001B[39mdtype, np\u001B[38;5;241m.\u001B[39minteger) \u001B[38;5;28;01melse\u001B[39;00m \u001B[38;5;241m1\u001B[39m\n",
      "\u001B[1;31mTypeError\u001B[0m: Invalid shape (1, 28, 28) for image data"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGiCAYAAACGUJO6AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGwdJREFUeJzt3X9M3dX9x/EX0HKpsdA6xoWyq6x1/ralgmVYG+dyJ4kG1z8WmTWFEX9MZUZ7s9liW1Crpau2I7NoY9XpHzqqRo2xBKdMYlSWRloSnW1NpRVmvLclrtyOKrTc8/1j316HBcsH+dG3PB/J5w/OPud+zj1h9+m9vfeS4JxzAgDAmMSJXgAAACNBwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmeQ7Y22+/reLiYs2aNUsJCQl65ZVXTjqnublZl1xyiXw+n84++2w9/fTTI1gqAABf8xywnp4ezZs3T3V1dcM6f9++fbrmmmt05ZVXqq2tTXfddZduuukmvf76654XCwDAcQnf5ct8ExIS9PLLL2vx4sVDnrN8+XJt27ZNH374YXzs17/+tQ4dOqTGxsaRXhoAMMlNGesLtLS0KBgMDhgrKirSXXfdNeSc3t5e9fb2xn+OxWL64osv9IMf/EAJCQljtVQAwBhwzunw4cOaNWuWEhNH760XYx6wcDgsv98/YMzv9ysajerLL7/UtGnTTphTU1Oj++67b6yXBgAYR52dnfrRj340arc35gEbicrKSoVCofjP3d3dOvPMM9XZ2anU1NQJXBkAwKtoNKpAIKDp06eP6u2OecAyMzMViUQGjEUiEaWmpg767EuSfD6ffD7fCeOpqakEDACMGu1/Ahrzz4EVFhaqqalpwNgbb7yhwsLCsb40AOB7zHPA/vOf/6itrU1tbW2S/vs2+ba2NnV0dEj678t/paWl8fNvvfVWtbe36+6779bu3bv16KOP6vnnn9eyZctG5x4AACYlzwF7//33NX/+fM2fP1+SFAqFNH/+fFVVVUmSPv/883jMJOnHP/6xtm3bpjfeeEPz5s3Thg0b9MQTT6ioqGiU7gIAYDL6Tp8DGy/RaFRpaWnq7u7m38AAwJixegznuxABACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDSiAJWV1ennJwcpaSkqKCgQNu3b//W82tra3Xuuedq2rRpCgQCWrZsmb766qsRLRgAAGkEAdu6datCoZCqq6u1Y8cOzZs3T0VFRTpw4MCg5z/33HNasWKFqqurtWvXLj355JPaunWr7rnnnu+8eADA5OU5YBs3btTNN9+s8vJyXXDBBdq8ebNOO+00PfXUU4Oe/95772nhwoVasmSJcnJydNVVV+n6668/6bM2AAC+jaeA9fX1qbW1VcFg8OsbSExUMBhUS0vLoHMuu+wytba2xoPV3t6uhoYGXX311UNep7e3V9FodMABAMD/muLl5K6uLvX398vv9w8Y9/v92r1796BzlixZoq6uLl1++eVyzunYsWO69dZbv/UlxJqaGt13331elgYAmGTG/F2Izc3NWrt2rR599FHt2LFDL730krZt26Y1a9YMOaeyslLd3d3x
o7Ozc6yXCQAwxtMzsPT0dCUlJSkSiQwYj0QiyszMHHTO6tWrtXTpUt10002SpIsvvlg9PT265ZZbtHLlSiUmnthQn88nn8/nZWkAgEnG0zOw5ORk5eXlqampKT4Wi8XU1NSkwsLCQeccOXLkhEglJSVJkpxzXtcLAIAkj8/AJCkUCqmsrEz5+flasGCBamtr1dPTo/LycklSaWmpsrOzVVNTI0kqLi7Wxo0bNX/+fBUUFGjv3r1avXq1iouL4yEDAMArzwErKSnRwYMHVVVVpXA4rNzcXDU2Nsbf2NHR0THgGdeqVauUkJCgVatW6bPPPtMPf/hDFRcX68EHHxy9ewEAmHQSnIHX8aLRqNLS0tTd3a3U1NSJXg4AwIOxegznuxABACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDSiAJWV1ennJwcpaSkqKCgQNu3b//W8w8dOqSKigplZWXJ5/PpnHPOUUNDw4gWDACAJE3xOmHr1q0KhULavHmzCgoKVFtbq6KiIu3Zs0cZGRknnN/X16df/OIXysjI0Isvvqjs7Gx9+umnmjFjxmisHwAwSSU455yXCQUFBbr00ku1adMmSVIsFlMgENAdd9yhFStWnHD+5s2b9dBDD2n37t2aOnXqiBYZjUaVlpam7u5upaamjug2AAATY6wewz29hNjX16fW1lYFg8GvbyAxUcFgUC0tLYPOefXVV1VYWKiKigr5/X5ddNFFWrt2rfr7+4e8Tm9vr6LR6IADAID/5SlgXV1d6u/vl9/vHzDu9/sVDocHndPe3q4XX3xR/f39amho0OrVq7VhwwY98MADQ16npqZGaWlp8SMQCHhZJgBgEhjzdyHGYjFlZGTo8ccfV15enkpKSrRy5Upt3rx5yDmVlZXq7u6OH52dnWO9TACAMZ7exJGenq6kpCRFIpEB45FIRJmZmYPOycrK0tSpU5WUlBQfO//88xUOh9XX16fk5OQT5vh8Pvl8Pi9LAwBMMp6egSUnJysvL09NTU3xsVgspqamJhUWFg46Z+HChdq7d69isVh87OOPP1ZWVtag8QIAYDg8v4QYCoW0ZcsWPfPMM9q1a5duu+029fT0qLy8XJJUWlqqysrK+Pm33XabvvjiC9155536+OOPtW3bNq1du1YVFRWjdy8AAJOO58+BlZSU6ODBg6qqqlI4HFZubq4aGxvjb+zo6OhQYuLXXQwEAnr99de1bNkyzZ07V9nZ2brzzju1fPny0bsXAIBJx/PnwCYCnwMDALtOic+BAQBwqiBgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAA
AJMIGADAJAIGADCJgAEATCJgAACTCBgAwKQRBayurk45OTlKSUlRQUGBtm/fPqx59fX1SkhI0OLFi0dyWQAA4jwHbOvWrQqFQqqurtaOHTs0b948FRUV6cCBA986b//+/fr973+vRYsWjXixAAAc5zlgGzdu1M0336zy8nJdcMEF2rx5s0477TQ99dRTQ87p7+/XDTfcoPvuu0+zZ88+6TV6e3sVjUYHHAAA/C9PAevr61Nra6uCweDXN5CYqGAwqJaWliHn3X///crIyNCNN944rOvU1NQoLS0tfgQCAS/LBABMAp4C1tXVpf7+fvn9/gHjfr9f4XB40DnvvPOOnnzySW3ZsmXY16msrFR3d3f86Ozs9LJMAMAkMGUsb/zw4cNaunSptmzZovT09GHP8/l88vl8Y7gyAIB1ngKWnp6upKQkRSKRAeORSESZmZknnP/JJ59o//79Ki4ujo/FYrH/XnjKFO3Zs0dz5swZyboBAJOcp5cQk5OTlZeXp6ampvhYLBZTU1OTCgsLTzj/vPPO0wcffKC2trb4ce211+rKK69UW1sb/7YFABgxzy8hhkIhlZWVKT8/XwsWLFBtba16enpUXl4uSSotLVV2drZqamqUkpKiiy66aMD8GTNmSNIJ4wAAeOE5YCUlJTp48KCqqqoUDoeVm5urxsbG+Bs7Ojo6lJjIF3wAAMZWgnPOTfQiTiYajSotLU3d3d1KTU2d6OUAADwYq8dwnioBAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMCkEQWsrq5OOTk5SklJUUFBgbZv3z7kuVu2bNGiRYs0c+ZMzZw5U8Fg8FvPBwBgODwHbOvWrQqFQqqurtaOHTs0b948FRUV6cCBA4Oe39zcrOuvv15vvfWWWlpaFAgEdNVVV+mzzz77zosHAExeCc4552VCQUGBLr30Um3atEmSFIvFFAgEdMcdd2jFihUnnd/f36+ZM2dq06ZNKi0tHfSc3t5e9fb2xn+ORqMKBALq7u5Wamqql+UCACZYNBpVWlraqD+Ge3oG1tfXp9bWVgWDwa9vIDFRwWBQLS0tw7qNI0eO6OjRozrjjDOGPKempkZpaWnxIxAIeFkmAGAS8BSwrq4u9ff3y+/3Dxj3+/0Kh8PDuo3ly5dr1qxZAyL4TZWVleru7o4fnZ2dXpYJAJgEpoznxdatW6f6+no1NzcrJSVlyPN8Pp98Pt84rgwAYI2ngKWnpyspKUmRSGTAeCQSUWZm5rfOffjhh7Vu3Tq9+eabmjt3rveVAgDwPzy9hJicnKy8vDw1NTXFx2KxmJqamlRYWDjkvPXr12vNmjVqbGxUfn7+yFcLAMD/8/wSYigUUllZmfLz87VgwQLV1taqp6dH5eXlkqTS0lJlZ2erpqZGkvTHP/5RVVVVeu6555STkxP/t7LTTz9dp59++ijeFQDAZOI5YCUlJTp48KCqqqoUDoeVm5urxsbG+Bs7Ojo6lJj49RO7xx57TH19ffrVr3414Haqq6t17733frfVAwAmLc+fA5sIY/UZAgDA2DslPgcGAMCpgoABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwi
YAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAk0YUsLq6OuXk5CglJUUFBQXavn37t57/wgsv6LzzzlNKSoouvvhiNTQ0jGixAAAc5zlgW7duVSgUUnV1tXbs2KF58+apqKhIBw4cGPT89957T9dff71uvPFG7dy5U4sXL9bixYv14YcffufFAwAmrwTnnPMyoaCgQJdeeqk2bdokSYrFYgoEArrjjju0YsWKE84vKSlRT0+PXnvttfjYT3/6U+Xm5mrz5s2DXqO3t1e9vb3xn7u7u3XmmWeqs7NTqampXpYLAJhg0WhUgUBAhw4dUlpa2ujdsPOgt7fXJSUluZdffnnAeGlpqbv22msHnRMIBNyf/vSnAWNVVVVu7ty5Q16nurraSeLg4ODg+B4dn3zyiZfknNQUedDV1aX+/n75/f4B436/X7t37x50TjgcHvT8cDg85HUqKysVCoXiPx86dEhnnXWWOjo6Rrfe3zPH/yuHZ6rfjn06OfZoeNin4Tn+KtoZZ5wxqrfrKWDjxefzyefznTCelpbGL8kwpKamsk/DwD6dHHs0POzT8CQmju4b3z3dWnp6upKSkhSJRAaMRyIRZWZmDjonMzPT0/kAAAyHp4AlJycrLy9PTU1N8bFYLKampiYVFhYOOqewsHDA+ZL0xhtvDHk+AADD4fklxFAopLKyMuXn52vBggWqra1VT0+PysvLJUmlpaXKzs5WTU2NJOnOO+/UFVdcoQ0bNuiaa65RfX293n//fT3++OPDvqbP51N1dfWgLyvia+zT8LBPJ8ceDQ/7NDxjtU+e30YvSZs2bdJDDz2kcDis3Nxc/fnPf1ZBQYEk6Wc/+5lycnL09NNPx89/4YUXtGrVKu3fv18/+clPtH79el199dWjdicAAJPPiAIGAMBE47sQAQAmETAAgEkEDABgEgEDAJh0ygSMP9EyPF72acuWLVq0aJFmzpypmTNnKhgMnnRfvw+8/i4dV19fr4SEBC1evHhsF3iK8LpPhw4dUkVFhbKysuTz+XTOOedMiv/fed2n2tpanXvuuZo2bZoCgYCWLVumr776apxWOzHefvttFRcXa9asWUpISNArr7xy0jnNzc265JJL5PP5dPbZZw945/qwjeo3K45QfX29S05Odk899ZT75z//6W6++WY3Y8YMF4lEBj3/3XffdUlJSW79+vXuo48+cqtWrXJTp051H3zwwTivfHx53aclS5a4uro6t3PnTrdr1y73m9/8xqWlpbl//etf47zy8eN1j47bt2+fy87OdosWLXK//OUvx2exE8jrPvX29rr8/Hx39dVXu3feecft27fPNTc3u7a2tnFe+fjyuk/PPvus8/l87tlnn3X79u1zr7/+usvKynLLli0b55WPr4aGBrdy5Ur30ksvOUknfOH7N7W3t7vTTjvNhUIh99FHH7lHHnnEJSUlucbGRk/XPSUCtmDBAldRURH/ub+/382aNcvV1NQMev51113nrrnmmgFjBQUF7re//e2YrnOied2nbzp27JibPn26e+aZZ8ZqiRNuJHt07Ngxd9lll7knnnjClZWVTYqAed2nxx57zM2ePdv19fWN1xJPCV73qaKiwv385z8fMBYKhdzChQvHdJ2nkuEE7O6773YXXnjhgLGSkhJXVFTk6VoT/hJiX1+fWltbFQwG42OJiYkKBoNqaWkZdE5LS8uA8yWpqKhoyPO/D0ayT9905MgR
HT16dNS/EfpUMdI9uv/++5WRkaEbb7xxPJY54UayT6+++qoKCwtVUVEhv9+viy66SGvXrlV/f/94LXvcjWSfLrvsMrW2tsZfZmxvb1dDQwNf3PANo/UYPuHfRj9ef6LFupHs0zctX75cs2bNOuEX5/tiJHv0zjvv6Mknn1RbW9s4rPDUMJJ9am9v19///nfdcMMNamho0N69e3X77bfr6NGjqq6uHo9lj7uR7NOSJUvU1dWlyy+/XM45HTt2TLfeeqvuueee8ViyGUM9hkejUX355ZeaNm3asG5nwp+BYXysW7dO9fX1evnll5WSkjLRyzklHD58WEuXLtWWLVuUnp4+0cs5pcViMWVkZOjxxx9XXl6eSkpKtHLlyiH/qvpk1dzcrLVr1+rRRx/Vjh079NJLL2nbtm1as2bNRC/te2nCn4HxJ1qGZyT7dNzDDz+sdevW6c0339TcuXPHcpkTyuseffLJJ9q/f7+Ki4vjY7FYTJI0ZcoU7dmzR3PmzBnbRU+AkfwuZWVlaerUqUpKSoqPnX/++QqHw+rr61NycvKYrnkijGSfVq9eraVLl+qmm26SJF188cXq6enRLbfcopUrV47638OyaqjH8NTU1GE/+5JOgWdg/ImW4RnJPknS+vXrtWbNGjU2Nio/P388ljphvO7Reeedpw8++EBtbW3x49prr9WVV16ptrY2BQKB8Vz+uBnJ79LChQu1d+/eeOAl6eOPP1ZWVtb3Ml7SyPbpyJEjJ0TqePQdXzsbN2qP4d7eXzI26uvrnc/nc08//bT76KOP3C233OJmzJjhwuGwc865pUuXuhUrVsTPf/fdd92UKVPcww8/7Hbt2uWqq6snzdvovezTunXrXHJysnvxxRfd559/Hj8OHz48UXdhzHndo2+aLO9C9LpPHR0dbvr06e53v/ud27Nnj3vttddcRkaGe+CBBybqLowLr/tUXV3tpk+f7v7617+69vZ297e//c3NmTPHXXfddRN1F8bF4cOH3c6dO93OnTudJLdx40a3c+dO9+mnnzrnnFuxYoVbunRp/Pzjb6P/wx/+4Hbt2uXq6ursvo3eOeceeeQRd+aZZ7rk5GS3YMEC949//CP+v11xxRWurKxswPnPP/+8O+ecc1xycrK78MIL3bZt28Z5xRPDyz6dddZZTtIJR3V19fgvfBx5/V36X5MlYM5536f33nvPFRQUOJ/P52bPnu0efPBBd+zYsXFe9fjzsk9Hjx519957r5szZ45LSUlxgUDA3X777e7f//73+C98HL311luDPtYc35uysjJ3xRVXnDAnNzfXJScnu9mzZ7u//OUvnq/Ln1MBAJg04f8GBgDASBAwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABg0v8Bc0z++5j1+JwAAAAASUVORK5CYII="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 70
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "def show_imgs(n_rows, n_cols, train_ds, class_names):\n",
    "    \"\"\"Show the first ``n_rows * n_cols`` samples of a dataset as an image grid.\n",
    "\n",
    "    Args:\n",
    "        n_rows: Number of grid rows.\n",
    "        n_cols: Number of grid columns.\n",
    "        train_ds: Indexable dataset whose items are ``(image, label)`` pairs; each\n",
    "            image must be 3-D with the channel axis first (it is transposed with\n",
    "            axes ``(1, 2, 0)`` before plotting).\n",
    "        class_names: Sequence indexed by ``label`` that gives the subplot title.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the grid asks for more samples than ``len(train_ds)``.\n",
    "    \"\"\"\n",
    "    n_imgs = n_rows * n_cols\n",
    "    # `<=` rather than `<`: a grid showing every sample of the dataset is valid.\n",
    "    assert n_imgs <= len(train_ds), \"grid is larger than the dataset\"\n",
    "    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))  # (width, height)\n",
    "    for index in range(n_imgs):  # row-major position in the grid, 0-based\n",
    "        plt.subplot(n_rows, n_cols, index + 1)  # subplot numbering is 1-based\n",
    "        img_arr, label = train_ds[index]\n",
    "        # Move the channel axis last: imshow rejects channels-first input.\n",
    "        img_arr = np.transpose(np.asarray(img_arr), (1, 2, 0))\n",
    "        plt.imshow(img_arr, cmap=\"binary\", interpolation=\"nearest\")\n",
    "        plt.axis(\"off\")  # hide the axes\n",
    "        plt.title(class_names[label])  # class name as the title\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Known image classes; the label list is documented at\n",
    "# https://github.com/zalandoresearch/fashion-mnist\n",
    "class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress',\n",
    "               'Coat', 'Sandal', 'Shirt', 'Sneaker',\n",
    "               'Bag', 'Ankle boot']  # names for labels 0-9\n",
    "# Only the first 15 samples are shown.\n",
    "show_imgs(3, 5, train_ds, class_names)\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# 从数据集到dataloader\n",
    "train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True) #batch_size分批，shuffle洗牌\n",
    "val_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "在PyTorch中，`DataLoader`是一个可迭代对象（iterable），它封装了数据的批量加载过程，使得在训练机器学习模型时可以方便地批量加载数据。`DataLoader`主要负责以下几个方面：\n",
    "\n",
    "1. **批量加载数据**：`DataLoader`可以将数据集（Dataset）切分为更小的批次（batch），每次迭代提供一小批量数据，而不是单个数据点。这有助于模型学习数据中的统计依赖性，并且可以更高效地利用GPU等硬件的并行计算能力。\n",
    "\n",
    "2. **数据打乱**：`shuffle`参数默认为`False`；设置`shuffle=True`后（如上面的`train_loader`），`DataLoader`会在每个epoch（训练周期）开始时重新打乱数据的顺序。这有助于模型训练时避免陷入局部最优解，并且可以提高模型的泛化能力。验证集不需要打乱，所以`val_loader`使用`shuffle=False`。\n",
    "\n",
    "3. **多进程数据加载**：`DataLoader`支持通过参数`num_workers`启动多个子进程来并行地加载数据（默认`num_workers=0`，即在主进程中加载），这可以显著减少训练过程中的等待时间，尤其是在处理大规模数据集时。\n",
    "\n",
    "4. **数据预处理**：归一化、标准化、数据增强等预处理通常由`Dataset`中的`transforms`完成；`DataLoader`取样本时会调用`Dataset`，因此每个批次拿到的都是预处理后的数据，再由`collate_fn`把多个样本合并成一个批次。\n",
    "\n",
    "5. **内存管理**：`DataLoader`按批次逐步取数据，而不是一次性处理整个数据集；配合`pin_memory=True`还可以加快数据从CPU到GPU的拷贝。\n",
    "\n",
    "6. **易用性**：`DataLoader`提供了一个简单的接口，可以很容易地集成到训练循环中。\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "for datas, labels in train_loader:\n",
    "    print(datas.shape)\n",
    "    print(labels.shape)\n",
    "    break"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "# 查看val_loader",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "for datas, labels in val_loader:\n",
    "    print(datas.shape)\n",
    "    print(labels.shape)\n",
    "    break"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__() # 继承父类的初始化方法，子类有父类的属性\n",
    "        self.flatten = nn.Flatten()  # 展平层\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(784, 300),  # in_features=784, out_features=300, 784是输入特征数，300是输出特征数\n",
    "            nn.ReLU(), # 激活函数\n",
    "            nn.Linear(300, 100),#隐藏层神经元数100\n",
    "            nn.ReLU(), # 激活函数\n",
    "            nn.Linear(100, 10),#输出层神经元数10\n",
    "        )\n",
    "\n",
    "    def forward(self, x): # 前向计算\n",
    "        # x.shape [batch size, 1, 28, 28]\n",
    "        x = self.flatten(x)  \n",
    "        # 展平后 x.shape [batch size, 784]\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        # logits.shape [batch size, 10]\n",
    "        return logits #没有经过softmax,称为logits\n",
    "    \n",
    "model = NeuralNetwork()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# 看看网络结构\n",
    "model"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# 查看模型运算的tensor尺寸\n",
    "x = torch.rand(32, 1, 28, 28)\n",
    "print(x.shape)\n",
    "print(model(x).shape)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "784*300+300+300*100+100+100*10+10"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "for name, param in model.named_parameters(): # 打印模型参数\n",
    "      print(name, param.shape)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# 看看模型参数\n",
    "list(model.parameters())  # 这种方法拿到模型的所有可学习参数,requires_grad=True\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# model.state_dict()  # 这种方法用于保存模型参数，可以看到每个参数属于模型的哪一部分"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练\n",
    "\n",
    "pytorch的训练需要自行实现，包括\n",
    "1. 定义损失函数\n",
    "2. 定义优化器\n",
    "3. 定义训练步\n",
    "4. 训练"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# 1. 定义损失函数 采用交叉熵损失\n",
    "loss_fct = nn.CrossEntropyLoss() #内部先做softmax，然后计算交叉熵\n",
    "# 2. 定义优化器 采用SGD\n",
    "# Optimizers specified in the torch.optim package,随机梯度下降\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "@torch.no_grad() # 装饰器，禁止反向传播，节省内存\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    loss_list = [] # 记录损失\n",
    "    pred_list = [] # 记录预测\n",
    "    label_list = [] # 记录标签\n",
    "    for datas, labels in dataloader:#10000/32=312\n",
    "        datas = datas.to(device) # 转到GPU\n",
    "        labels = labels.to(device) # 转到GPU\n",
    "        # 前向计算\n",
    "        logits = model(datas)\n",
    "        loss = loss_fct(logits, labels)         # 验证集损失\n",
    "        loss_list.append(loss.item()) # 记录损失\n",
    "        \n",
    "        preds = logits.argmax(axis=-1)    # 验证集预测,argmax返回最大值索引\n",
    "        # print(preds)\n",
    "        pred_list.extend(preds.cpu().numpy().tolist())#将PyTorch张量转换为NumPy数组。只有当张量在CPU上时，这个转换才是合法的\n",
    "        # print(preds.cpu().numpy().tolist())\n",
    "        label_list.extend(labels.cpu().numpy().tolist())\n",
    "        \n",
    "    acc = accuracy_score(label_list, pred_list) # 计算准确率\n",
    "    return np.mean(loss_list), acc\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "1875*20"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Training\n",
    "def training(model, train_loader, val_loader, epoch, loss_fct, optimizer, eval_step=500):\n",
    "    \"\"\"Train ``model`` for ``epoch`` passes over ``train_loader``.\n",
    "\n",
    "    Every optimisation step appends the batch loss/accuracy to the \"train\" history;\n",
    "    every ``eval_step`` steps (step 0 included) the model is evaluated on\n",
    "    ``val_loader`` and the result is appended to the \"val\" history.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys \"train\" and \"val\", each a list of\n",
    "        ``{\"loss\", \"acc\", \"step\"}`` records.\n",
    "    \"\"\"\n",
    "    history = {\"train\": [], \"val\": []}\n",
    "    step = 0\n",
    "    model.train()\n",
    "    total_steps = epoch * len(train_loader)  # one progress-bar tick per batch\n",
    "    with tqdm(total=total_steps) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            for batch_x, batch_y in train_loader:\n",
    "                batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
    "\n",
    "                optimizer.zero_grad()  # clear gradients left from the previous step\n",
    "                logits = model(batch_x)  # forward pass\n",
    "                batch_loss = loss_fct(logits, batch_y)\n",
    "                batch_loss.backward()  # fills .grad on the model parameters\n",
    "                optimizer.step()  # updates the parameters from the stored .grad\n",
    "\n",
    "                batch_acc = accuracy_score(\n",
    "                    batch_y.cpu().numpy(), logits.argmax(axis=-1).cpu().numpy()\n",
    "                )\n",
    "                history[\"train\"].append(\n",
    "                    {\"loss\": batch_loss.cpu().item(), \"acc\": batch_acc, \"step\": step}\n",
    "                )\n",
    "\n",
    "                if step % eval_step == 0:\n",
    "                    model.eval()  # evaluation mode for the validation pass\n",
    "                    val_loss, val_acc = evaluating(model, val_loader, loss_fct)\n",
    "                    history[\"val\"].append(\n",
    "                        {\"loss\": val_loss, \"acc\": val_acc, \"step\": step}\n",
    "                    )\n",
    "                    model.train()  # back to training mode\n",
    "\n",
    "                step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "\n",
    "    return history\n",
    "\n",
    "\n",
    "epoch = 20  # original note: change to 40\n",
    "model = model.to(device)\n",
    "record = training(model, train_loader, val_loader, epoch, loss_fct, optimizer, eval_step=1000)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "record[\"train\"][-5:]"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "record[\"val\"][-5:]"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Note when plotting: the loss is not necessarily within [0, 1].\n",
    "def plot_learning_curves(record_dict, sample_step=1000):\n",
    "    \"\"\"Plot training and validation curves, one subplot per recorded metric.\n",
    "\n",
    "    Args:\n",
    "        record_dict: Dict with \"train\" and \"val\" lists of records; every record\n",
    "            has a \"step\" key plus one key per metric (e.g. \"loss\", \"acc\").\n",
    "        sample_step: Keep only every ``sample_step``-th training record.\n",
    "    \"\"\"\n",
    "    train_df = pd.DataFrame(record_dict[\"train\"]).set_index(\"step\").iloc[::sample_step]\n",
    "    val_df = pd.DataFrame(record_dict[\"val\"]).set_index(\"step\")\n",
    "    last_step = train_df.index[-1]  # last sampled training step\n",
    "    # Print the accuracy series only when that metric was actually recorded.\n",
    "    if \"acc\" in train_df.columns:\n",
    "        print(train_df['acc'])\n",
    "        print(val_df['acc'])\n",
    "\n",
    "    # Ticks every 5000 steps, labelled in thousands; computed once for all axes.\n",
    "    tick_positions = list(range(0, last_step, 5000))\n",
    "    tick_labels = [f\"{int(pos / 1000)}k\" for pos in tick_positions]\n",
    "\n",
    "    fig_num = len(train_df.columns)  # one subplot per metric\n",
    "    # squeeze=False always returns a 2-D array of axes, so indexing also works\n",
    "    # when only a single metric was recorded.\n",
    "    fig, axs = plt.subplots(1, fig_num, figsize=(5 * fig_num, 5), squeeze=False)\n",
    "    for ax, item in zip(axs[0], train_df.columns):\n",
    "        ax.plot(train_df.index, train_df[item], label=f\"train_{item}\")\n",
    "        ax.plot(val_df.index, val_df[item], label=f\"val_{item}\")\n",
    "        ax.grid()  # show the grid\n",
    "        ax.legend()  # show the legend\n",
    "        ax.set_xticks(tick_positions)\n",
    "        ax.set_xticklabels(tick_labels)\n",
    "        ax.set_xlabel(\"step\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_learning_curves(record)  # the x-axis is the training step"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# dataload for evaluating\n",
    "\n",
    "model.eval() # 进入评估模式\n",
    "loss, acc = evaluating(model, val_loader, loss_fct)\n",
    "print(f\"loss:     {loss:.4f}\\naccuracy: {acc:.4f}\")"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
